--
-- Licensed to the Apache Software Foundation (ASF) under one or more
-- contributor license agreements.  See the NOTICE file distributed with
-- this work for additional information regarding copyright ownership.
-- The ASF licenses this file to You under the Apache License, Version 2.0
-- (the "License"); you may not use this file except in compliance with
-- the License.  You may obtain a copy of the License at
--
--     http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
--
-- Module table; it is returned at the end of this file and carries the driver API.
local _M = {}

-- Metatable applied to driver instances in `_M.new`, so that fields missing on
-- an instance are looked up on the module table.
local mt = {
    __index = _M
}

-- Media type used both to validate the incoming request's Content-Type and as
-- its fallback when the client sent none.
local CONTENT_TYPE_JSON = "application/json"

local core = require("apisix.core")
local plugin = require("apisix.plugin")
local http = require("resty.http")
local url  = require("socket.url")
local sse  = require("apisix.plugins.ai-drivers.sse")
local ngx  = ngx
local ngx_now = ngx.now

local table = table
local pairs = pairs
local type  = type
local math  = math
local ipairs = ipairs
local setmetatable = setmetatable

local HTTP_INTERNAL_SERVER_ERROR = ngx.HTTP_INTERNAL_SERVER_ERROR
local HTTP_GATEWAY_TIMEOUT = ngx.HTTP_GATEWAY_TIMEOUT


-- Create a driver instance from `opts`.
--
-- Fields read from `opts`:
--   host, port, path        copied onto the instance unchanged
--   options.remove_model    copied as `remove_model` when `options` is present
--
-- Returns the new instance, whose metatable resolves missing fields on `_M`.
function _M.new(opts)
    local options = opts.options

    local instance = setmetatable({}, mt)
    instance.host = opts.host
    instance.port = opts.port
    instance.path = opts.path
    instance.remove_model = options and options.remove_model

    return instance
end


-- Check that the client request is JSON and return its decoded body.
--
-- A missing Content-Type header is treated as application/json.
--
-- Returns the decoded request table on success, or `nil, err` when the
-- content type is not JSON or the body cannot be obtained as a JSON table.
function _M.validate_request(ctx)
    local content_type = core.request.header(ctx, "Content-Type") or CONTENT_TYPE_JSON

    local is_json = core.string.has_prefix(content_type, CONTENT_TYPE_JSON)
    if not is_json then
        return nil, "unsupported content-type: " .. content_type
                    .. ", only application/json is supported"
    end

    local body_tab, err = core.request.get_json_request_body_table()
    if body_tab then
        return body_tab, nil
    end

    return nil, err
end


-- Translate a transport error message into an HTTP status code:
-- 504 when the message mentions "timeout", 500 for anything else.
local function handle_error(err)
    local timed_out = core.string.find(err, "timeout")
    if not timed_out then
        return HTTP_INTERNAL_SERVER_ERROR
    end

    return HTTP_GATEWAY_TIMEOUT
end


-- Whole milliseconds elapsed since `ctx.llm_request_start_time`.
local function elapsed_ms(ctx)
    return math.floor((ngx_now() - ctx.llm_request_start_time) * 1000)
end


-- Relay the LLM response `res` to the client and record metrics on `ctx`.
--
-- Two modes, chosen by the upstream Content-Type:
--   * "text/event-stream": the body is read chunk by chunk; every chunk is
--     decoded as SSE to collect delta contents and token usage, then handed
--     to `plugin.lua_response_filter` as-is.
--   * anything else: the whole body is read, decoded as JSON to extract
--     token usage and message contents, then handed to
--     `plugin.lua_response_filter` in one piece.
--
-- Values written: ctx.var.apisix_upstream_response_time,
-- ctx.var.llm_time_to_first_token, ctx.var.llm_prompt_tokens,
-- ctx.var.llm_completion_tokens, ctx.var.llm_response_text,
-- ctx.var.llm_request_done, ctx.llm_raw_usage, ctx.ai_token_usage and
-- ctx.llm_response_contents_in_chunk.
--
-- Returns an HTTP status code when the body is missing or cannot be read,
-- and nothing otherwise.
local function read_response(ctx, res)
    local body_reader = res.body_reader
    if not body_reader then
        core.log.warn("AI service sent no response body")
        return HTTP_INTERNAL_SERVER_ERROR
    end

    local content_type = res.headers["Content-Type"]
    core.response.set_header("Content-Type", content_type)

    if content_type and core.string.find(content_type, "text/event-stream") then
        -- every delta content seen so far in this stream
        local contents = {}
        while true do
            local chunk, err = body_reader() -- will read chunk by chunk
            ctx.var.apisix_upstream_response_time = elapsed_ms(ctx)
            if err then
                core.log.warn("failed to read response chunk: ", err)
                return handle_error(err)
            end
            if not chunk then
                -- end of stream
                return
            end

            -- NOTE(review): presumably "0" is the initial value of this
            -- variable, so only the first chunk records it -- confirm.
            if ctx.var.llm_time_to_first_token == "0" then
                ctx.var.llm_time_to_first_token = elapsed_ms(ctx)
            end

            local events = sse.decode(chunk)
            -- contents of the current chunk only, reset for every chunk
            ctx.llm_response_contents_in_chunk = {}
            for _, event in ipairs(events) do
                if event.type == "message" then
                    local data, err = core.json.decode(event.data)
                    -- Valid JSON that is not an object/array (null, number,
                    -- boolean, string) must be skipped too: indexing it
                    -- below would raise a runtime error.
                    if type(data) ~= "table" then
                        core.log.warn("failed to decode SSE data: ",
                                      err or "not a JSON object")
                        goto CONTINUE
                    end

                    if type(data.choices) == "table" and #data.choices > 0 then
                        for _, choice in ipairs(data.choices) do
                            if type(choice) == "table"
                                    and type(choice.delta) == "table"
                                    and type(choice.delta.content) == "string" then
                                core.table.insert(contents, choice.delta.content)
                                core.table.insert(ctx.llm_response_contents_in_chunk,
                                                        choice.delta.content)
                            end
                        end
                    end

                    -- usage field is null for non-last events, null is parsed as userdata type
                    if type(data.usage) == "table" then
                        core.log.info("got token usage from ai service: ",
                                            core.json.delay_encode(data.usage))
                        ctx.llm_raw_usage = data.usage
                        ctx.ai_token_usage = {
                            prompt_tokens = data.usage.prompt_tokens or 0,
                            completion_tokens = data.usage.completion_tokens or 0,
                            total_tokens = data.usage.total_tokens or 0,
                        }
                        ctx.var.llm_prompt_tokens = ctx.ai_token_usage.prompt_tokens
                        ctx.var.llm_completion_tokens = ctx.ai_token_usage.completion_tokens
                        ctx.var.llm_response_text = table.concat(contents, "")
                    end
                elseif event.type == "done" then
                    ctx.var.llm_request_done = true
                end

                ::CONTINUE::
            end

            -- the raw chunk is forwarded untouched
            plugin.lua_response_filter(ctx, res.headers, chunk)
        end
    end

    local raw_res_body, err = res:read_body()
    if not raw_res_body then
        core.log.warn("failed to read response body: ", err)
        return handle_error(err)
    end
    ngx.status = res.status
    -- the body arrives in one piece, so both timings are the same value
    ctx.var.llm_time_to_first_token = elapsed_ms(ctx)
    ctx.var.apisix_upstream_response_time = ctx.var.llm_time_to_first_token

    local res_body, err = core.json.decode(raw_res_body)
    -- A body such as `null` or `123` decodes without an error but is not a
    -- table; treat it like an undecodable body instead of indexing it.
    if err or type(res_body) ~= "table" then
        core.log.warn("invalid response body from ai service: ", raw_res_body, " err: ",
            err or "not a JSON object", ", it will cause token usage not available")
    else
        core.log.info("got token usage from ai service: ", core.json.delay_encode(res_body.usage))
        ctx.ai_token_usage = {}
        if type(res_body.usage) == "table" then
            ctx.llm_raw_usage = res_body.usage
            ctx.ai_token_usage.prompt_tokens = res_body.usage.prompt_tokens or 0
            ctx.ai_token_usage.completion_tokens = res_body.usage.completion_tokens or 0
            ctx.ai_token_usage.total_tokens = res_body.usage.total_tokens or 0
        end
        ctx.var.llm_prompt_tokens = ctx.ai_token_usage.prompt_tokens or 0
        ctx.var.llm_completion_tokens = ctx.ai_token_usage.completion_tokens or 0
        if type(res_body.choices) == "table" and #res_body.choices > 0 then
            local contents = {}
            for _, choice in ipairs(res_body.choices) do
                if type(choice) == "table"
                        and type(choice.message) == "table"
                        and type(choice.message.content) == "string" then
                    core.table.insert(contents, choice.message.content)
                end
            end
            local content_to_check = table.concat(contents, " ")
            ctx.var.llm_response_text = content_to_check
        end
    end
    -- the raw body is forwarded even when it could not be decoded
    plugin.lua_response_filter(ctx, res.headers, raw_res_body)
end


-- Send `request_table` to the LLM service and relay the response.
--
-- Parameters:
--   self           driver instance created by `_M.new`
--   ctx            request context, passed on to `read_response`
--   conf           plugin configuration; reads `timeout`, `ssl_verify`,
--                  `keepalive`, `keepalive_timeout` and `keepalive_pool`
--   request_table  request body as a Lua table. It is modified in place:
--                  `model_options` are merged in and `model` is cleared
--                  when the instance has `remove_model` set.
--   extra_opts     optional table; reads `endpoint`, `query_params`,
--                  `headers` and `model_options`. `query_params` and
--                  `headers` are modified in place when provided.
--
-- Returns:
--   * an HTTP status code when the client cannot be created, the connection
--     or request fails, or the service answers 429 / 5xx (so that the
--     caller can retry);
--   * `nil, err` when the request body cannot be encoded as JSON;
--   * otherwise whatever `read_response` returns.
function _M.request(self, ctx, conf, request_table, extra_opts)
    -- every field of `extra_opts` is optional, and so is the table itself
    extra_opts = extra_opts or {}

    local httpc, err = http.new()
    if not httpc then
        core.log.error("failed to create http client to send request to LLM server: ", err)
        return HTTP_INTERNAL_SERVER_ERROR
    end
    httpc:set_timeout(conf.timeout)

    local endpoint = extra_opts.endpoint
    local parsed_url
    if endpoint then
        parsed_url = url.parse(endpoint)
    end

    local scheme = parsed_url and parsed_url.scheme or "https"
    local host = parsed_url and parsed_url.host or self.host
    -- NOTE(review): `self.port` is stored by `_M.new` but not consulted here;
    -- without an explicit port in `endpoint` the scheme default is used.
    local port = parsed_url and parsed_url.port
    if not port then
        if scheme == "https" then
            port = 443
        else
            port = 80
        end
    end
    local ok, err = httpc:connect({
        scheme = scheme,
        host = host,
        port = port,
        ssl_verify = conf.ssl_verify,
        ssl_server_name = host,
    })

    if not ok then
        core.log.warn("failed to connect to LLM server: ", err)
        return handle_error(err)
    end

    local query_params = extra_opts.query_params

    if type(parsed_url) == "table" and parsed_url.query and #parsed_url.query > 0 then
        local args_tab = core.string.decode_args(parsed_url.query)
        if type(args_tab) == "table" then
            -- create the table only when there is something to merge, so a
            -- request without any query arguments keeps `query` unset
            query_params = query_params or {}
            core.table.merge(query_params, args_tab)
        end
    end

    local path = (parsed_url and parsed_url.path or self.path)

    local headers = extra_opts.headers or {}
    headers["Content-Type"] = CONTENT_TYPE_JSON
    local params = {
        method = "POST",
        headers = headers,
        ssl_verify = conf.ssl_verify,
        path = path,
        query = query_params
    }

    if extra_opts.model_options then
        for opt, val in pairs(extra_opts.model_options) do
            request_table[opt] = val
        end
    end
    if self.remove_model then
        request_table.model = nil
    end
    local req_json, err = core.json.encode(request_table)
    if not req_json then
        -- the connection is already open; do not leave it dangling
        httpc:close()
        return nil, err
    end

    params.body = req_json

    local res, err = httpc:request(params)
    if not res then
        core.log.warn("failed to send request to LLM server: ", err)
        httpc:close()
        return handle_error(err)
    end

    -- handling this error separately is needed for retries
    if res.status == 429 or (res.status >= 500 and res.status < 600) then
        -- the body is not consumed, so the connection cannot be reused;
        -- release it before the caller retries
        httpc:close()
        return res.status
    end

    local code, body = read_response(ctx, res)

    if conf.keepalive then
        local ok, err = httpc:set_keepalive(conf.keepalive_timeout, conf.keepalive_pool)
        if not ok then
            core.log.warn("failed to keepalive connection: ", err)
        end
    end

    return code, body
end


return _M
